林嶔 (Lin, Chin)
Lesson 16
– In the early days of neural networks, training a very deep network often required extensive trial-and-error tuning of hyperparameters, and sometimes even the help of autoencoders, until Residual Learning put an end to all of that.
– Let us review our cross-entropy loss function and its derivative:
\[ \begin{align} CE(y, p) & = \frac{{1}}{n}\sum \limits_{i=1}^{n} -\left(y_{i} \cdot log(p_{i}) + (1-y_{i}) \cdot log(1-p_{i})\right) \\ \frac{\partial}{\partial p}CE(y, p) & = \frac{p-y}{p(1-p)} \end{align} \]
\[ \begin{align} S(x) & =\frac{1}{1+e^{-x}} \\ \frac{\partial}{\partial x}S(x) & = S(x)(1-S(x)) \end{align} \]
– But what about the Generator? It now needs substantial updates to try to fool the Discriminator again, yet as \(p \rightarrow 1\) its gradient will be…
\[ \begin{align} CE(0, p) & = - log(1-p) \\ \frac{\partial}{\partial p}CE(0, p) & = \frac{p}{p(1-p)} = \frac{1}{1-p} \end{align} \]
\[ \begin{align} \frac{\partial}{\partial x}CE(0, S(x)) & = \frac{S(x)^2(1-S(x))}{S(x)(1-S(x))} = S(x) \end{align} \]
– Moreover, for the Discriminator (and for ordinary networks) the gradient gradually shrinks as \(p \rightarrow 1\), which resembles learning-rate decay; for the Generator, however, it is as if the learning rate were increasing, and a larger learning rate does not necessarily mean faster convergence!
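As a quick numeric check of this claim, a few lines of plain R (no mxnet needed) evaluate both gradients with respect to \(x\): the discriminator-side gradient \(S(x) - 1\) vanishes as \(p \rightarrow 1\), while the generator-side gradient \(S(x)\) grows toward 1.
S <- function(x) 1 / (1 + exp(-x))
x <- c(0, 2, 4, 8)
data.frame(p = S(x), # p -> 1 as x grows
           dis_grad = S(x) - 1, # d/dx CE(1, S(x)): shrinks toward 0
           gen_grad = S(x)) # d/dx CE(0, S(x)): grows toward 1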
Remove the Sigmoid function, because in extreme situations it makes the numerical approximations inaccurate.
No matter which of the Discriminator and the Generator has the upper hand, choose a smooth loss function that describes the current state of the contest.
Note: a smaller value means better!
\[ \begin{align} loss(y, x) & = (1-y)x - yx \end{align} \]
The meaning of this function is that when \(y=0\) it wants \(x \rightarrow -\infty\), and when \(y=1\) it wants \(x \rightarrow \infty\).
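A minimal sketch in plain R shows why this loss avoids the saturation problem: its gradient with respect to \(x\) is the constant \(1 - 2y\), so the update signal never vanishes no matter how extreme \(x\) becomes.
w_loss_example <- function(y, x) (1 - y) * x - y * x
w_loss_example(y = 0, x = c(-10, 0, 10)) # increasing in x, so minimizing pushes x toward -Inf
w_loss_example(y = 1, x = c(-10, 0, 10)) # decreasing in x, so minimizing pushes x toward +Inf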
In addition, WGAN applies a special operation to the Discriminator's weights: every weight in the Discriminator is clipped so that its magnitude does not exceed a certain value (typically set to 0.1).
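The clipping operation itself is simple; here is a plain-R sketch (the mxnet NDArray version appears inside the training loop below):
w_limit <- 0.1
clip_weights <- function(w, limit = w_limit) pmin(pmax(w, -limit), limit) # force every weight into [-limit, limit]
clip_weights(c(-0.5, -0.05, 0.02, 0.3)) # -0.10 -0.05 0.02 0.10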
Let us implement WGAN on MNIST. As before, first download the train_data.csv, sub_train_data.csv, and test_data.csv files used in the previous lesson.
Here we continue with the Conditional GAN from the previous lesson, and the Iterator is essentially unchanged (the only difference is that the label noise has been removed):
library(imager)
library(magrittr)
library(mxnet)
my_iterator_func <- setRefClass("Custom_Iter",
fields = c("iter", "data.csv", "data.shape", "batch.size"),
contains = "Rcpp_MXArrayDataIter",
methods = list(
initialize = function(iter, data.csv, data.shape, batch.size){
csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
.self$iter <- csv_iter
.self
},
value = function(){
val <- as.array(.self$iter$value()$data)
val.x <- val[-1,]
batch_size <- ncol(val.x)
val.x <- val.x / 255 # Important
dim(val.x) <- c(28, 28, 1, batch_size)
val.x <- mx.nd.array(val.x)
digit.real <- mx.nd.array(val[1,])
digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))
digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))
rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
rand <- array(rand, dim = c(1, 1, 10, batch_size))
rand <- mx.nd.array(rand)
label.real <- array(0, dim = c(1, 1, 1, batch_size))
label.real <- mx.nd.array(label.real)
label.fake <- array(1, dim = c(1, 1, 1, batch_size))
label.fake <- mx.nd.array(label.fake)
label.gen <- array(0, dim = c(1, 1, 1, batch_size)) # the generator pretends its fakes are real (label 0)
label.gen <- mx.nd.array(label.gen)
list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
},
iter.next = function(){
.self$iter$iter.next()
},
reset = function(){
.self$iter$reset()
},
finalize=function(){
}
)
)
my_iter <- my_iterator_func(iter = NULL, data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)
gen_data <- mx.symbol.Variable('data')
gen_digit <- mx.symbol.Variable('digit')
gen_concat <- mx.symbol.concat(data = list(gen_data, gen_digit), num.args = 2, dim = 1, name = "gen_concat")
gen_deconv1 <- mx.symbol.Deconvolution(data = gen_concat, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')
gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')
gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')
gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')
dis_img <- mx.symbol.Variable('img')
dis_digit <- mx.symbol.Variable("digit")
dis_label <- mx.symbol.Variable('label')
dis_concat <- mx.symbol.broadcast_mul(lhs = dis_img, rhs = dis_digit, name = 'dis_concat')
dis_conv1 <- mx.symbol.Convolution(data = dis_concat, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.2, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')
dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.2, name = "dis_relu2")
dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.2, name = "dis_relu3")
dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.2, name = "dis_relu4")
dis_pred <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_pred')
# WGAN-style loss: minimizing it pushes the score up when label = 1 and down when label = 0
w_loss_pos <- mx.symbol.broadcast_mul(dis_pred, dis_label)
w_loss_neg <- mx.symbol.broadcast_mul(dis_pred, 1 - dis_label)
w_loss_mean <- mx.symbol.mean(w_loss_neg - w_loss_pos)
w_loss <- mx.symbol.MakeLoss(w_loss_mean, name = 'w_loss')
gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
gen_executor <- mx.simple.bind(symbol = gen_pred,
data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32),
ctx = mx.cpu(), grad.req = "write")
dis_executor <- mx.simple.bind(symbol = w_loss,
img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32),
ctx = mx.cpu(), grad.req = "write")
# Initial parameters
mx.set.seed(0)
gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
input.shape = list(data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32)),
output.shape = NULL,
initializer = mxnet:::mx.init.uniform(0.01),
ctx = mx.cpu())
dis_arg <- mxnet:::mx.model.init.params(symbol = w_loss,
input.shape = list(img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32)),
output.shape = NULL,
initializer = mxnet:::mx.init.uniform(0.01),
ctx = mx.cpu())
# Update parameters
mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)
set.seed(0)
n.epoch <- 20
w_limit <- 0.1
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in 1:n.epoch) {
current_batch <- 0
my_iter$reset()
while (my_iter$iter.next()) {
my_values <- my_iter$value()
# Generator (forward)
mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']], digit = my_values[['digit.fake']]), match.name = TRUE)
mx.exec.forward(gen_executor, is.train = TRUE)
gen_pred_output <- gen_executor$ref.outputs[[1]]
# Discriminator (fake)
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.fake']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
# Discriminator (real)
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], digit = my_values[['digit.real']], label = my_values[['label.real']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
# Weight clipping (only for discriminator)
dis_weight_names <- grep('weight', names(dis_executor$ref.arg.arrays), value = TRUE)
for (k in dis_weight_names) {
current_dis_weight <- dis_executor$ref.arg.arrays[[k]] %>% as.array()
current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
list()
names(current_dis_weight_list) <- k
mx.exec.update.arg.arrays(dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
}
# Generator (backward)
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.gen']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
img_grads <- dis_executor$ref.grad.arrays[['img']]
mx.exec.backward(gen_executor, out_grads = img_grads)
gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
if (current_batch %% 100 == 0) {
# Show current images
current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
for (i in 1:9) {
img <- as.array(gen_pred_output)[,,,i]
plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
}
# Show loss
message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
}
current_batch <- current_batch + 1
}
pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
for (i in 1:9) {
img <- as.array(gen_pred_output)[,,,i]
plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
}
dev.off()
gen_model <- list()
gen_model$symbol <- gen_pred
gen_model$arg.params <- gen_executor$ref.arg.arrays[-c(1:2)]
gen_model$aux.params <- gen_executor$ref.aux.arrays
class(gen_model) <- "MXFeedForwardModel"
dis_model <- list()
dis_model$symbol <- dis_pred
dis_model$arg.params <- dis_executor$ref.arg.arrays[!names(dis_executor$ref.arg.arrays) %in% c('img', 'digit', 'label')] # drop all inputs (including 'label') so only weights are saved
dis_model$aux.params <- dis_executor$ref.aux.arrays
class(dis_model) <- "MXFeedForwardModel"
mx.model.save(model = gen_model, prefix = 'model/cwgen_v1', iteration = j)
mx.model.save(model = dis_model, prefix = 'model/cwdis_v1', iteration = j)
}
range_logger <- logger %>% unlist %>% range
plot(logger$gen_loss, type = 'l', col = 'red', lwd = 0.5, ylim = range_logger, xlab = 'Batch', ylab = 'loss')
lines(1:length(logger$dis_real_loss), logger$dis_real_loss, col = 'blue', lwd = 0.5)
lines(1:length(logger$dis_fake_loss), logger$dis_fake_loss, col = 'darkgreen', lwd = 0.5)
legend('topright', c('Gen', 'Real', 'Fake'), col = c('red', 'blue', 'darkgreen'), lwd = 1)
cwgen_model <- mx.model.load('model/cwgen_v1', iteration = n.epoch) # load the last saved epoch (models were saved at iterations 1 to n.epoch)
my_predict <- function (model, digits = 0:9) {
batch_size <- length(digits)
gen_executor <- mx.simple.bind(symbol = model$symbol,
data = c(1, 1, 10, batch_size), digit = c(1, 1, 10, batch_size),
ctx = mx.cpu())
mx.exec.update.arg.arrays(gen_executor, model$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, model$aux.params, match.name = TRUE)
noise_array <- rnorm(batch_size * 10, mean = 0, sd = 1)
noise_array <- array(noise_array, dim = c(1, 1, 10, batch_size))
noise_array <- mx.nd.array(noise_array)
digit_array <- mx.nd.array(digits)
digit_array <- mx.nd.one.hot(indices = digit_array, depth = 10)
digit_array <- mx.nd.reshape(data = digit_array, shape = c(1, 1, -1, batch_size))
mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = noise_array, digit = digit_array), match.name = TRUE)
mx.exec.forward(gen_executor, is.train = FALSE)
gen_pred_output <- gen_executor$ref.outputs[[1]]
return(as.array(gen_pred_output))
}
pred_img <- my_predict(model = cwgen_model, digits = 0:9)
par(mfrow = c(2, 5), mar = c(0.1, 0.1, 0.1, 0.1))
for (i in 1:10) {
img <- pred_img[,,,i]
plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}
– This whole family of GANs, frankly, amounts to tweaking the loss function. LSGAN uses the squared-error function, so the loss becomes:
\[ \begin{align} loss(y, x) & = (x-y)^2 \end{align} \]
The LSGAN paper suggests two common settings for the fake label \(a\), the real label \(b\), and the value \(c\) that the generator wants the discriminator to assign to its fakes:
\(a = -1\), \(b = 1\), \(c = 0\)
\(a = 0\), \(b = 1\), \(c = 1\)
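As a minimal sketch in plain R, the squared-error loss yields the gradient \(2(x - y)\), which scales with the distance from the target, so even a confidently wrong output still produces a strong learning signal:
ls_loss_example <- function(y, x) (x - y)^2
ls_grad_example <- function(y, x) 2 * (x - y)
ls_grad_example(y = 0, x = c(0.1, 1, 5)) # 0.2 2.0 10.0: grows with the error instead of saturating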
# LSGAN loss: mean squared error between the discriminator output and the label
loss_diff <- mx.symbol.broadcast_minus(dis_pred, dis_label)
loss_square_diff <- mx.symbol.square(loss_diff)
loss_mean <- mx.symbol.mean(loss_square_diff)
ls_loss <- mx.symbol.MakeLoss(loss_mean, name = 'ls_loss')
my_iterator_func <- setRefClass("Custom_Iter",
fields = c("iter", "data.csv", "data.shape", "batch.size"),
contains = "Rcpp_MXArrayDataIter",
methods = list(
initialize = function(iter, data.csv, data.shape, batch.size){
csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
.self$iter <- csv_iter
.self
},
value = function(){
val <- as.array(.self$iter$value()$data)
val.x <- val[-1,]
batch_size <- ncol(val.x)
val.x <- val.x / 255 # Important
dim(val.x) <- c(28, 28, 1, batch_size)
val.x <- mx.nd.array(val.x)
digit.real <- mx.nd.array(val[1,])
digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))
digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))
rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
rand <- array(rand, dim = c(1, 1, 10, batch_size))
rand <- mx.nd.array(rand)
label.real <- array(0, dim = c(1, 1, 1, batch_size))
label.real <- mx.nd.array(label.real)
label.fake <- array(1, dim = c(1, 1, 1, batch_size))
label.fake <- mx.nd.array(label.fake)
label.gen <- array(0, dim = c(1, 1, 1, batch_size)) # the generator targets the 'real' label (0), as in the WGAN iterator above
label.gen <- mx.nd.array(label.gen)
list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
},
iter.next = function(){
.self$iter$iter.next()
},
reset = function(){
.self$iter$reset()
},
finalize=function(){
}
)
)
my_iter <- my_iterator_func(iter = NULL, data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)
– This is the model proposed by Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros in their 2017 paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
– What is more impressive is that the condition CycleGAN takes as input is not a structured condition like that of a Conditional GAN; instead, a whole image is used directly as the condition.
– In many domains it is impossible to collect paired data, for example when translating between photographs and artistic paintings, so this capability is extremely important! In addition, paired data is far more difficult to collect, which is exactly what makes CycleGAN so attractive.
– Suppose there are two functions: \(G(X)\), which transforms \(X\) into \(\hat{Y}\), and \(F(Y)\), which transforms \(Y\) into \(\hat{X}\). When both functions reach a perfect state, we must have \(F(G(X)) = X\) and \(G(F(Y)) = Y\).
\[ \begin{align} \mbox{cycle consistency loss} & = |F(G(X)) - X| + |G(F(Y)) - Y| \end{align} \]
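A toy illustration in plain R, using hypothetical translators G and F_ (named F_ to avoid masking R's built-in F): when the two mappings invert each other perfectly, both terms of the loss vanish.
G <- function(x) 2 * x # toy X -> Y translator
F_ <- function(y) y / 2 # toy Y -> X translator
X <- runif(5); Y <- runif(5)
mean(abs(F_(G(X)) - X)) + mean(abs(G(F_(Y)) - Y)) # 0 for a perfectly consistent cycle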
Note: the original paper uses the L1 loss (as above); replacing it with the L2 loss or another loss function makes little difference.
– Now we introduce two Discriminators, \(D_x(X)\) and \(D_y(Y)\), which judge whether the translated images are real or fake. The contest between Generator and Discriminator gives us an adversarial loss, exactly as in all previous GANs; of course, we can also use the WGAN loss:
\[ \begin{align} \mbox{adversarial loss for Discriminator(x)} & = D_x(F(Y)) - D_x(X) \\ \mbox{adversarial loss for Discriminator(y)} & = D_y(G(X)) - D_y(Y) \\ \mbox{adversarial loss for Generator(x)} & = - D_x(F(Y)) \\ \mbox{adversarial loss for Generator(y)} & = - D_y(G(X)) \end{align} \]
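Sketched in plain R with hypothetical scalar critics (the mx.symbol version actually used for training appears in the CycleGAN code below), where smaller is better for each player:
adv_dis <- function(D, real, fake) D(fake) - D(real) # the critic pushes fake scores down, real scores up
adv_gen <- function(D, fake) -D(fake) # the generator pushes its fake scores up
Dx <- function(img) mean(img) # a toy critic scoring the mean pixel value
adv_dis(Dx, real = runif(10), fake = runif(10))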
– Here is the original photo:
library(OpenImageR)
library(jpeg)
photo <- readJPEG('images/header.jpg')
resize_photo <- resizeImage(image = photo,
width = 648,
height = 256,
method = "bilinear")
Show_img <- function (img) {
par(mai = rep(0, 4))
plot(NA, xlim = c(0.04, 0.96), ylim = c(0.96, 0.04), xaxt = "n", yaxt = "n", bty = "n")
rasterImage(as.raster(img), 0, 1, 1, 0, interpolate = FALSE)
}
Show_img(resize_photo)
library(mxnet)
P2M_gen_model <- mx.model.load(prefix = 'model/P2M_gen_v1', iteration = 0)
my_predict <- function (model, img) {
dim(img) <- c(dim(img), 1)
P2M_executor <- mx.simple.bind(symbol = model$symbol,
P2M_img = dim(img),
ctx = mx.cpu())
mx.exec.update.arg.arrays(P2M_executor, model$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(P2M_executor, model$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(P2M_executor, arg.arrays = list(P2M_img = mx.nd.array(img)), match.name = TRUE)
mx.exec.forward(P2M_executor, is.train = FALSE)
P2M_pred_output <- P2M_executor$ref.outputs[[1]]
return(as.array(P2M_pred_output)[,,,1])
}
monet_img <- my_predict(model = P2M_gen_model, img = resize_photo)
Show_img(monet_img)
– To solve this problem, we need to develop an additional identity mapping loss!
– Accordingly, we define the identity mapping loss as follows (note that \(G\) is originally responsible for \(X \rightarrow Y\), while \(F\) is originally responsible for \(Y \rightarrow X\)):
\[ \begin{align} \mbox{identity mapping loss} & = |F(X) - X| + |G(Y) - Y| \end{align} \]
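In plain R this is again just an L1 distance; a minimal sketch with hypothetical translator arguments F_ and G:
identity_loss <- function(F_, G, X, Y) mean(abs(F_(X) - X)) + mean(abs(G(Y) - Y))
identity_loss(F_ = identity, G = identity, X = runif(5), Y = runif(5)) # 0 for perfect identity maps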
P2M_gen_model <- mx.model.load(prefix = 'model/P2M_gen_v2', iteration = 0)
monet_img <- my_predict(model = P2M_gen_model, img = resize_photo)
Show_img(monet_img)
– Let us see the effect! This is the result of passing the image through once more (2 passes in total):
monet_img <- my_predict(model = P2M_gen_model, img = monet_img)
Show_img(monet_img)
– Let us try one more pass (3 passes in total):
monet_img <- my_predict(model = P2M_gen_model, img = monet_img)
Show_img(monet_img)
– You can probably guess even without looking closely: since the input size equals the output size, the structure will be similar to the networks we used earlier for segmentation.
Next, let us practice implementing a complete CycleGAN. Because of limited computing resources we use a stripped-down model (for the full training process, see xup6fup/MxNetR-CycleGAN). Here we experiment with only 5 Monet paintings plus 5 real photos, and the files have been downscaled to 64×64; please download the experiment files here.
We first load the packages and specify the model parameters (you can try increasing n.epoch):
# Libraries
library(mxnet)
library(imager)
library(jpeg)
library(magrittr)
# Parameters
CTX <- mx.cpu()
Batch_size <- 1
num_show_img <- 1
n.epoch <- 10
n.print <- 20
w_limit <- 0.1
learning_rate <- 1e-4
lambda_cycle_consistency_loss <- 10
lambda_identity_mapping_loss <- 5
model_name <- 'mini'
# Load data
load('data/mini_train_list.RData')
# Iterator function
my_iterator_core <- function (batch_size) {
batch <- 0
batch_per_epoch <- floor(length(train_list[[2]])/batch_size)
reset <- function() {batch <<- 0}
iter.next <- function() {
batch <<- batch + 1
if (batch > batch_per_epoch) {return(FALSE)} else {return(TRUE)}
}
value <- function() {
idx <- 1:batch_size + (batch - 1) * batch_size
img_array.1 <- array(0, dim = c(64, 64, 3, batch_size))
img_array.2 <- array(0, dim = c(64, 64, 3, batch_size))
for (i in 1:batch_size) {
img_array.1[,,,i] <- readJPEG(train_list[[1]][[idx[i]]])
img_array.2[,,,i] <- readJPEG(train_list[[2]][[idx[i]]])
}
img_array.1 <- mx.nd.array(img_array.1)
img_array.2 <- mx.nd.array(img_array.2)
return(list(monet = img_array.1, photo = img_array.2))
}
return(list(reset = reset, iter.next = iter.next, value = value, batch_size = batch_size, batch = batch))
}
my_iterator_func <- setRefClass("Custom_Iter",
fields = c("iter", "batch_size"),
contains = "Rcpp_MXArrayDataIter",
methods = list(
initialize = function(iter, batch_size = 16){
.self$iter <- my_iterator_core(batch_size = batch_size)
.self
},
value = function(){
.self$iter$value()
},
iter.next = function(){
.self$iter$iter.next()
},
reset = function(){
.self$iter$reset()
},
finalize=function(){
}
)
)
# Build an iterator
my_iter <- my_iterator_func(iter = NULL, batch_size = Batch_size)
# Show image function
Show_img <- function (img) {
plot(NA, xlim = c(0.04, 0.96), ylim = c(0.96, 0.04), xaxt = "n", yaxt = "n", bty = "n")
rasterImage(as.raster(img), 0, 1, 1, 0, interpolate = FALSE)
}
# Test the iterator
my_iter$reset()
my_iter$iter.next()
## [1] TRUE
test_data <- my_iter$value()
par(mai = rep(0, 4))
Show_img(as.array(test_data[[1]])[,,,1])
Residual.CONV_Module <- function (indata, num_filters = 128, kernel_size = 3, relu_slope = 0, name = 'g1', stage = 1) {
Conv.1 <- mx.symbol.Convolution(data = indata, kernel = c(kernel_size, kernel_size), stride = c(1, 1),
pad = c((kernel_size - 1)/2, (kernel_size - 1)/2),
no.bias = TRUE, num.filter = num_filters,
name = paste0(name, '_', stage, '_Conv.1'))
InstNorm.1 <- mx.symbol.InstanceNorm(data = Conv.1, name = paste0(name, '_', stage, '_InstNorm.1'))
ReLU.1 <- mx.symbol.LeakyReLU(data = InstNorm.1, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU.1'))
Conv.2 <- mx.symbol.Convolution(data = ReLU.1, kernel = c(kernel_size, kernel_size), stride = c(1, 1),
pad = c((kernel_size - 1)/2, (kernel_size - 1)/2),
no.bias = TRUE, num.filter = num_filters,
name = paste0(name, '_', stage, '_Conv.2'))
InstNorm.2 <- mx.symbol.InstanceNorm(data = Conv.2, name = paste0(name, '_', stage, '_InstNorm.2'))
ReLU.2 <- mx.symbol.LeakyReLU(data = InstNorm.2, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU.2'))
ResBlock <- mx.symbol.broadcast_plus(lhs = indata, rhs = ReLU.2, name = paste0(name, '_', stage, '_ResBlock'))
return(ResBlock)
}
general.CONV_Module <- function (indata, num_filters = 128, kernel_size = 3, stride = 1, pad = 1, relu_slope = 0, drop_p = 0, name = 'g1', stage = 1, normalization = FALSE) {
Drop <- mx.symbol.Dropout(data = indata, p = drop_p, name = paste0(name, '_', stage, '_Drop'))
if (normalization) {
Conv <- mx.symbol.Convolution(data = Drop, kernel = c(kernel_size, kernel_size), stride = c(stride, stride),
pad = c(pad, pad),
no.bias = TRUE, num.filter = num_filters,
name = paste0(name, '_', stage, '_Conv'))
InstNorm <- mx.symbol.InstanceNorm(data = Conv, name = paste0(name, '_', stage, '_InstNorm'))
ReLU <- mx.symbol.LeakyReLU(data = InstNorm, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU'))
return(ReLU)
} else {
Conv <- mx.symbol.Convolution(data = Drop, kernel = c(kernel_size, kernel_size), stride = c(stride, stride),
pad = c(pad, pad),
no.bias = FALSE, num.filter = num_filters,
name = paste0(name, '_', stage, '_Conv'))
return(Conv)
}
}
DECONV_Module <- function (indata, updata = NULL, num_filters = 128, relu_slope = 0, name = 'g1', stage = 1) {
DeConv <- mx.symbol.Deconvolution(data = indata, kernel = c(2, 2), stride = c(2, 2),
num_filter = num_filters,
name = paste0(name, '_', stage, '_DeConv'))
InstNorm <- mx.symbol.InstanceNorm(data = DeConv, name = paste0(name, '_', stage, '_InstNorm'))
ReLU <- mx.symbol.LeakyReLU(data = InstNorm, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU'))
if (is.null(updata)) {
return(ReLU)
} else {
DenBlock <- mx.symbol.concat(data = list(updata, ReLU), num.args = 2, dim = 1, name = paste0(name, '_', stage, '_DenBlock'))
return(DenBlock)
}
}
Generator_symbol <- function (name = 'g1') {
g_img <- mx.symbol.Variable(paste0(name, '_img'))
g_1 <- general.CONV_Module(indata = g_img, num_filters = 8, kernel_size = 7, stride = 1, pad = 3, relu_slope = 0, drop_p = 0, name = name, stage = 1, normalization = TRUE)
g_2 <- general.CONV_Module(indata = g_1, num_filters = 16, kernel_size = 3, stride = 2, pad = 1, relu_slope = 0, drop_p = 0, name = name, stage = 2, normalization = TRUE)
g_3 <- general.CONV_Module(indata = g_2, num_filters = 32, kernel_size = 3, stride = 2, pad = 1, relu_slope = 0, drop_p = 0, name = name, stage = 3, normalization = TRUE)
g_4 <- Residual.CONV_Module(indata = g_3, num_filters = 32, kernel_size = 3, relu_slope = 0, name = name, stage = 4)
g_5 <- Residual.CONV_Module(indata = g_4, num_filters = 32, kernel_size = 3, relu_slope = 0, name = name, stage = 5)
g_6 <- DECONV_Module(indata = g_5, updata = g_2, num_filters = 16, relu_slope = 0, name = name, stage = 6)
g_7 <- DECONV_Module(indata = g_6, updata = g_1, num_filters = 8, relu_slope = 0, name = name, stage = 7)
g_8 <- general.CONV_Module(indata = g_7, num_filters = 3, kernel_size = 7, stride = 1, pad = 3, relu_slope = 0, drop_p = 0, name = name, stage = 8, normalization = FALSE)
g_pred <- mx.symbol.Activation(data = g_8, act_type = "sigmoid", name = paste0(name, '_pred'))
return(g_pred)
}
Discriminator_symbol <- function (name = 'd1', drop_p = 0) {
d_img <- mx.symbol.Variable(paste0(name, '_img'))
d_1 <- general.CONV_Module(indata = d_img, num_filters = 8, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 1, normalization = TRUE)
d_2 <- general.CONV_Module(indata = d_1, num_filters = 16, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 2, normalization = TRUE)
d_3 <- general.CONV_Module(indata = d_2, num_filters = 32, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 3, normalization = TRUE)
d_4 <- general.CONV_Module(indata = d_3, num_filters = 64, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 4, normalization = TRUE)
d_5 <- general.CONV_Module(indata = d_4, num_filters = 1, kernel_size = 1, stride = 1, pad = 0, relu_slope = 0, drop_p = drop_p, name = name, stage = 5, normalization = FALSE)
d_pred <- mx.symbol.mean(data = d_5, axis = 1:3, keepdims = FALSE, name = paste0(name, '_pred'))
return(d_pred)
}
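As an optional sanity check, we can infer the symbol's shapes to confirm that the generator preserves the spatial size (64×64×3 in, 64×64×3 out), matching the segmentation-style structure mentioned earlier; the 'test' name here is just a throwaway:
test_gen <- Generator_symbol(name = 'test')
test_shapes <- mx.symbol.infer.shape(test_gen, test_img = c(64, 64, 3, 1))
test_shapes$out.shapes # expected: 64 64 3 1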
adversarial_loss <- function (pred, label, lambda = 1) {
loss_pos <- mx.symbol.broadcast_mul(pred, label)
loss_neg <- mx.symbol.broadcast_mul(pred, 1 - label)
loss_mean <- mx.symbol.mean(loss_neg - loss_pos)
weighted_loss_mean <- loss_mean * lambda
adversarial_loss <- mx.symbol.MakeLoss(weighted_loss_mean)
return(adversarial_loss)
}
cycle_consistency_loss <- function (pred, label, lambda = 10) {
diff_pred_label <- mx.symbol.broadcast_minus(lhs = pred, rhs = label)
abs_diff_pred_label <- mx.symbol.abs(data = diff_pred_label)
mean_loss <- mx.symbol.mean(data = abs_diff_pred_label, axis = 0:3, keepdims = FALSE)
weighted_mean_loss <- mean_loss * lambda
cycle_consistency_loss <- mx.symbol.MakeLoss(weighted_mean_loss)
return(cycle_consistency_loss)
}
identity_mapping_loss <- function (pred, label, lambda = 5) {
diff_pred_label <- mx.symbol.broadcast_minus(lhs = pred, rhs = label)
abs_diff_pred_label <- mx.symbol.abs(data = diff_pred_label)
mean_loss <- mx.symbol.mean(data = abs_diff_pred_label, axis = 0:3, keepdims = FALSE)
weighted_mean_loss <- mean_loss * lambda
identity_mapping_loss <- mx.symbol.MakeLoss(weighted_mean_loss)
return(identity_mapping_loss)
}
# Generator-1 (Monet to Photo)
M2P_gen <- Generator_symbol(name = 'M2P')
# Generator-2 (Photo to Monet)
P2M_gen <- Generator_symbol(name = 'P2M')
# Discriminator-1 (Monet)
Monet_dis <- Discriminator_symbol(name = 'Monet', drop_p = 0)
# Discriminator-2 (Photo)
Photo_dis <- Discriminator_symbol(name = 'Photo', drop_p = 0)
# adversarial loss-1 (Monet)
label <- mx.symbol.Variable('label')
Monet_loss <- adversarial_loss(pred = Monet_dis, label = label, lambda = 1)
# adversarial loss-2 (Photo)
label <- mx.symbol.Variable('label')
Photo_loss <- adversarial_loss(pred = Photo_dis, label = label, lambda = 1)
# cycle consistency loss
pred <- mx.symbol.Variable('pred')
label <- mx.symbol.Variable('label')
CC_loss <- cycle_consistency_loss(pred = pred, label = label, lambda = lambda_cycle_consistency_loss)
# identity mapping loss
pred <- mx.symbol.Variable('pred')
label <- mx.symbol.Variable('label')
IM_loss <- identity_mapping_loss(pred = pred, label = label, lambda = lambda_identity_mapping_loss)
M2P_gen_executor <- mx.simple.bind(symbol = M2P_gen,
M2P_img = c(64, 64, 3, Batch_size),
ctx = CTX, grad.req = "write")
P2M_gen_executor <- mx.simple.bind(symbol = P2M_gen,
P2M_img = c(64, 64, 3, Batch_size),
ctx = CTX, grad.req = "write")
Monet_dis_executor <- mx.simple.bind(symbol = Monet_loss,
Monet_img = c(64, 64, 3, Batch_size), label = c(Batch_size),
ctx = CTX, grad.req = "write")
Photo_dis_executor <- mx.simple.bind(symbol = Photo_loss,
Photo_img = c(64, 64, 3, Batch_size), label = c(Batch_size),
ctx = CTX, grad.req = "write")
cycle_consistency_executor <- mx.simple.bind(symbol = CC_loss,
pred = c(64, 64, 3, Batch_size), label = c(64, 64, 3, Batch_size),
ctx = CTX, grad.req = "write")
identity_mapping_executor <- mx.simple.bind(symbol = IM_loss,
pred = c(64, 64, 3, Batch_size), label = c(64, 64, 3, Batch_size),
ctx = CTX, grad.req = "write")
# Initial parameters
mx.set.seed(0)
M2P_gen_arg <- mxnet:::mx.model.init.params(symbol = M2P_gen,
input.shape = list(M2P_img = c(64, 64, 3, Batch_size)),
output.shape = NULL,
initializer = mxnet:::mx.init.normal(0.02),
ctx = CTX)
P2M_gen_arg <- mxnet:::mx.model.init.params(symbol = P2M_gen,
input.shape = list(P2M_img = c(64, 64, 3, Batch_size)),
output.shape = NULL,
initializer = mxnet:::mx.init.normal(0.02),
ctx = CTX)
Monet_dis_arg <- mxnet:::mx.model.init.params(symbol = Monet_loss,
input.shape = list(Monet_img = c(64, 64, 3, Batch_size), label = c(Batch_size)),
output.shape = NULL,
initializer = mxnet:::mx.init.normal(0.02),
ctx = CTX)
Photo_dis_arg <- mxnet:::mx.model.init.params(symbol = Photo_loss,
input.shape = list(Photo_img = c(64, 64, 3, Batch_size), label = c(Batch_size)),
output.shape = NULL,
initializer = mxnet:::mx.init.normal(0.02),
ctx = CTX)
# Update parameters
mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(M2P_gen_executor, M2P_gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(P2M_gen_executor, P2M_gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(Monet_dis_executor, Monet_dis_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(Photo_dis_executor, Photo_dis_arg$aux.params, match.name = TRUE)
# Optimizers
M2P_gen_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
P2M_gen_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
Monet_dis_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
Photo_dis_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
# Updaters
M2P_gen_updater <- mx.opt.get.updater(optimizer = M2P_gen_optimizer, weights = M2P_gen_executor$ref.arg.arrays)
P2M_gen_updater <- mx.opt.get.updater(optimizer = P2M_gen_optimizer, weights = P2M_gen_executor$ref.arg.arrays)
Monet_dis_updater <- mx.opt.get.updater(optimizer = Monet_dis_optimizer, weights = Monet_dis_executor$ref.arg.arrays)
Photo_dis_updater <- mx.opt.get.updater(optimizer = Photo_dis_optimizer, weights = Photo_dis_executor$ref.arg.arrays)
# Start to train
for (j in 1:n.epoch) {
current_batch <- 0
t0 <- Sys.time()
my_iter$reset()
batch_logger <- list(Monet_adversarial_loss.gen = NULL,
Monet_adversarial_loss.fake = NULL,
Monet_adversarial_loss.real = NULL,
Photo_adversarial_loss.gen = NULL,
Photo_adversarial_loss.fake = NULL,
Photo_adversarial_loss.real = NULL,
Monet_cycle_consistency_loss = NULL,
Photo_cycle_consistency_loss = NULL,
Monet_identity_mapping_loss = NULL,
Photo_identity_mapping_loss = NULL)
while (my_iter$iter.next()) {
my_values <- my_iter$value()
##################################
# #
# Cycle consistency loss (Part1) #
# #
##################################
# Generator-1 forward (real Monet to fake Photo)
mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['monet']]), match.name = TRUE)
mx.exec.forward(M2P_gen_executor, is.train = TRUE)
fake.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
fake.Photo_img <- as.array(fake.Photo_output)
# Generator-2 forward (fake Photo to restored Monet)
mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = fake.Photo_output), match.name = TRUE)
mx.exec.forward(P2M_gen_executor, is.train = TRUE)
restored.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
restored.Monet_img <- as.array(restored.Monet_output)
# Cycle consistency loss (Monet)
mx.exec.update.arg.arrays(cycle_consistency_executor, arg.arrays = list(pred = restored.Monet_output, label = my_values[['monet']]), match.name = TRUE)
mx.exec.forward(cycle_consistency_executor, is.train = TRUE)
mx.exec.backward(cycle_consistency_executor)
batch_logger$Monet_cycle_consistency_loss <- c(batch_logger$Monet_cycle_consistency_loss, as.array(cycle_consistency_executor$ref.outputs[[1]]))
# Generator-2 backward
P2M_grads <- cycle_consistency_executor$ref.grad.arrays[['pred']]
mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
# Generator-1 backward
M2P_grads <- P2M_gen_executor$ref.grad.arrays[['P2M_img']]
mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
#################################
# #
# Identity mapping loss (Part1) #
# #
#################################
# Generator-1 forward (real Photo to fake Photo)
mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['photo']]), match.name = TRUE)
mx.exec.forward(M2P_gen_executor, is.train = TRUE)
mirror.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
mirror.Photo_img <- as.array(mirror.Photo_output)
# Identity mapping loss (Photo)
mx.exec.update.arg.arrays(identity_mapping_executor, arg.arrays = list(pred = mirror.Photo_output, label = my_values[['photo']]), match.name = TRUE)
mx.exec.forward(identity_mapping_executor, is.train = TRUE)
mx.exec.backward(identity_mapping_executor)
batch_logger$Photo_identity_mapping_loss <- c(batch_logger$Photo_identity_mapping_loss, as.array(identity_mapping_executor$ref.outputs[[1]]))
# Generator-1 backward
M2P_grads <- identity_mapping_executor$ref.grad.arrays[['pred']]
mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
############################
# #
# Adversarial loss (Part1) #
# #
############################
# Generator-1 forward (real Monet to fake Photo)
mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['monet']]), match.name = TRUE)
mx.exec.forward(M2P_gen_executor, is.train = TRUE)
fake.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
# Discriminator-2 fake (Photo)
mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = fake.Photo_output, label = mx.nd.array(rep(1, Batch_size))), match.name = TRUE)
mx.exec.forward(Photo_dis_executor, is.train = TRUE)
mx.exec.backward(Photo_dis_executor)
Photo_dis_update_args <- Photo_dis_updater(weight = Photo_dis_executor$ref.arg.arrays, grad = Photo_dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_update_args, skip.null = TRUE)
batch_logger$Photo_adversarial_loss.fake <- c(batch_logger$Photo_adversarial_loss.fake, as.array(Photo_dis_executor$ref.outputs[[1]]))
# Discriminator-2 real (Photo)
mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = my_values[['photo']], label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
mx.exec.forward(Photo_dis_executor, is.train = TRUE)
mx.exec.backward(Photo_dis_executor)
Photo_dis_update_args <- Photo_dis_updater(weight = Photo_dis_executor$ref.arg.arrays, grad = Photo_dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_update_args, skip.null = TRUE)
batch_logger$Photo_adversarial_loss.real <- c(batch_logger$Photo_adversarial_loss.real, as.array(Photo_dis_executor$ref.outputs[[1]]))
# Adversarial loss (Photo)
mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = fake.Photo_output, label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
mx.exec.forward(Photo_dis_executor, is.train = TRUE)
mx.exec.backward(Photo_dis_executor)
batch_logger$Photo_adversarial_loss.gen <- c(batch_logger$Photo_adversarial_loss.gen, as.array(Photo_dis_executor$ref.outputs[[1]]))
# Generator-1 backward
M2P_grads <- Photo_dis_executor$ref.grad.arrays[['Photo_img']]
mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
# Weight clipping (Discriminator-2)
if (!is.null(w_limit)) {
dis_weight_names <- grep('weight', names(Photo_dis_executor$ref.arg.arrays), value = TRUE)
for (k in dis_weight_names) {
current_dis_weight <- Photo_dis_executor$ref.arg.arrays[[k]] %>% as.array()
current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
list()
names(current_dis_weight_list) <- k
mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
}
}
##################################
# #
# Cycle consistency loss (Part2) #
# #
##################################
# Generator-2 forward (real Photo to fake Monet)
mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['photo']]), match.name = TRUE)
mx.exec.forward(P2M_gen_executor, is.train = TRUE)
fake.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
fake.Monet_img <- as.array(fake.Monet_output)
# Generator-1 forward (fake Monet to restored Photo)
mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = fake.Monet_output), match.name = TRUE)
mx.exec.forward(M2P_gen_executor, is.train = TRUE)
restored.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
restored.Photo_img <- as.array(restored.Photo_output)
# Cycle consistency loss (Photo)
mx.exec.update.arg.arrays(cycle_consistency_executor, arg.arrays = list(pred = restored.Photo_output, label = my_values[['photo']]), match.name = TRUE)
mx.exec.forward(cycle_consistency_executor, is.train = TRUE)
mx.exec.backward(cycle_consistency_executor)
batch_logger$Photo_cycle_consistency_loss <- c(batch_logger$Photo_cycle_consistency_loss, as.array(cycle_consistency_executor$ref.outputs[[1]]))
# Generator-1 backward
M2P_grads <- cycle_consistency_executor$ref.grad.arrays[['pred']]
mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
# Generator-2 backward
P2M_grads <- M2P_gen_executor$ref.grad.arrays[['M2P_img']]
mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
#################################
# #
# Identity mapping loss (Part2) #
# #
#################################
# Generator-2 forward (real Monet to fake Monet)
mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['monet']]), match.name = TRUE)
mx.exec.forward(P2M_gen_executor, is.train = TRUE)
mirror.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
mirror.Monet_img <- as.array(mirror.Monet_output)
# Identity mapping loss (Monet)
mx.exec.update.arg.arrays(identity_mapping_executor, arg.arrays = list(pred = mirror.Monet_output, label = my_values[['monet']]), match.name = TRUE)
mx.exec.forward(identity_mapping_executor, is.train = TRUE)
mx.exec.backward(identity_mapping_executor)
batch_logger$Monet_identity_mapping_loss <- c(batch_logger$Monet_identity_mapping_loss, as.array(identity_mapping_executor$ref.outputs[[1]]))
# Generator-2 backward
P2M_grads <- identity_mapping_executor$ref.grad.arrays[['pred']]
mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
############################
# #
# Adversarial loss (Part2) #
# #
############################
# Generator-2 forward (real Photo to fake Monet)
mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['photo']]), match.name = TRUE)
mx.exec.forward(P2M_gen_executor, is.train = TRUE)
fake.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
# Discriminator-1 fake (Monet)
mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = fake.Monet_output, label = mx.nd.array(rep(1, Batch_size))), match.name = TRUE)
mx.exec.forward(Monet_dis_executor, is.train = TRUE)
mx.exec.backward(Monet_dis_executor)
Monet_dis_update_args <- Monet_dis_updater(weight = Monet_dis_executor$ref.arg.arrays, grad = Monet_dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_update_args, skip.null = TRUE)
batch_logger$Monet_adversarial_loss.fake <- c(batch_logger$Monet_adversarial_loss.fake, as.array(Monet_dis_executor$ref.outputs[[1]]))
# Discriminator-1 real (Monet)
mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = my_values[['monet']], label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
mx.exec.forward(Monet_dis_executor, is.train = TRUE)
mx.exec.backward(Monet_dis_executor)
Monet_dis_update_args <- Monet_dis_updater(weight = Monet_dis_executor$ref.arg.arrays, grad = Monet_dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_update_args, skip.null = TRUE)
batch_logger$Monet_adversarial_loss.real <- c(batch_logger$Monet_adversarial_loss.real, as.array(Monet_dis_executor$ref.outputs[[1]]))
# Adversarial loss (Monet)
mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = fake.Monet_output, label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
mx.exec.forward(Monet_dis_executor, is.train = TRUE)
mx.exec.backward(Monet_dis_executor)
batch_logger$Monet_adversarial_loss.gen <- c(batch_logger$Monet_adversarial_loss.gen, as.array(Monet_dis_executor$ref.outputs[[1]]))
# Generator-2 backward
P2M_grads <- Monet_dis_executor$ref.grad.arrays[['Monet_img']]
mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
# Weight clipping (Discriminator-1)
if (!is.null(w_limit)) {
dis_weight_names <- grep('weight', names(Monet_dis_executor$ref.arg.arrays), value = TRUE)
for (k in dis_weight_names) {
current_dis_weight <- Monet_dis_executor$ref.arg.arrays[[k]] %>% as.array()
current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
list()
names(current_dis_weight_list) <- k
mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
}
}
############################
# #
# Show current performance #
# #
############################
if (current_batch %% n.print == 0) {
# Show current images
par(mfrow = c(num_show_img * 2, 4), mar = c(0.1, 0.1, 0.1, 0.1))
for (i in 1:num_show_img) {
Show_img(img = as.array(my_values[['monet']])[,,,i])
Show_img(img = as.array(fake.Photo_img)[,,,i])
Show_img(img = as.array(mirror.Monet_img)[,,,i])
Show_img(img = as.array(restored.Monet_img)[,,,i])
}
for (i in 1:num_show_img) {
Show_img(img = as.array(my_values[['photo']])[,,,i])
Show_img(img = as.array(fake.Monet_img)[,,,i])
Show_img(img = as.array(mirror.Photo_img)[,,,i])
Show_img(img = as.array(restored.Photo_img)[,,,i])
}
# Show speed
speed_per_batch <- as.numeric(Sys.time() - t0, units = 'secs') / (current_batch + 1)
# Show loss
current_loss <- batch_logger %>% sapply(., mean) %>% formatC(., 4, format = 'f')
message('Epoch [', j, '] Batch [', current_batch, '] loss list (', formatC(speed_per_batch, 2, format = 'f'), ' sec/batch):')
message(paste(paste(names(current_loss), current_loss, sep = ': '), collapse = '\n'))
}
current_batch <- current_batch + 1
}
# Save models
M2P_gen_model <- list()
M2P_gen_model$symbol <- M2P_gen
M2P_gen_model$arg.params <- M2P_gen_executor$ref.arg.arrays[-1]
M2P_gen_model$aux.params <- M2P_gen_executor$ref.aux.arrays
class(M2P_gen_model) <- "MXFeedForwardModel"
mx.model.save(model = M2P_gen_model, prefix = paste0('model/CycleGAN_', model_name, '/M2P_gen_', model_name), iteration = j)
P2M_gen_model <- list()
P2M_gen_model$symbol <- P2M_gen
P2M_gen_model$arg.params <- P2M_gen_executor$ref.arg.arrays[-1]
P2M_gen_model$aux.params <- P2M_gen_executor$ref.aux.arrays
class(P2M_gen_model) <- "MXFeedForwardModel"
mx.model.save(model = P2M_gen_model, prefix = paste0('model/CycleGAN_', model_name, '/P2M_gen_', model_name), iteration = j)
Monet_dis_model <- list()
Monet_dis_model$symbol <- Monet_dis
Monet_dis_model$arg.params <- Monet_dis_executor$ref.arg.arrays[!names(Monet_dis_executor$ref.arg.arrays) %in% c('Monet_img', 'label')] # drop both inputs, not just the image
Monet_dis_model$aux.params <- Monet_dis_executor$ref.aux.arrays
class(Monet_dis_model) <- "MXFeedForwardModel"
mx.model.save(model = Monet_dis_model, prefix = paste0('model/CycleGAN_', model_name, '/Monet_dis_', model_name), iteration = j)
Photo_dis_model <- list()
Photo_dis_model$symbol <- Photo_dis
Photo_dis_model$arg.params <- Photo_dis_executor$ref.arg.arrays[!names(Photo_dis_executor$ref.arg.arrays) %in% c('Photo_img', 'label')] # drop both inputs, not just the image
Photo_dis_model$aux.params <- Photo_dis_executor$ref.aux.arrays
class(Photo_dis_model) <- "MXFeedForwardModel"
mx.model.save(model = Photo_dis_model, prefix = paste0('model/CycleGAN_', model_name, '/Photo_dis_', model_name), iteration = j)
}
– Of course, to train this really well you may need to deepen and widen the model architecture and use the complete dataset, which demands enormous computing resources; you will probably need a GPU server.
– The variations on generative adversarial networks are fascinating. With concepts like CycleGAN, our research has begun moving toward unsupervised learning, a crucial step in making machines more human-like!
Since AlphaGo defeated top professional Go players in 2016, the terms "artificial intelligence" and "deep learning" have become wildly popular. AI's comeback is not only due to the boost from greater computing power; there have been major breakthroughs in both theory and application. After a semester of study, we should now grasp the core techniques, and conceptually we can imagine it doing almost anything!
Let us review the course and examine our own learning. On the theory side, you should thoroughly understand that its essence is a prediction function, and clearly know the difficulties that may arise during training and how to resolve them. On the application side, several classic problems have given rise to specialized network architectures, and you should understand how they work in image recognition, natural language processing, and generative modeling.
– Although there are still some tasks we did not implement together, you should now be able to imagine how things like image captioning and chat-style question answering are done, right?